#include "statsxx/machine_learning/NeuralNet.hpp"

// STL
#include <algorithm>                            // std::max(), std::min()
#include <cmath>                                // std::fabs()
#include <fstream>                              // std::ofstream
#include <tuple>                                // std::tie()
#include <utility>                              // std::pair
#include <vector>                               // std::vector

// jScience
#include "jmath_basic.h"                        // sign()
#include "jrandnum.hpp"                         // rand_num_uniform_Mersenne_twister()
#include "jutility.hpp"                         // get_n_rand_unique_elements()

// stats++
#include "statsxx/dataset.hpp"     // DataSet
#include "statsxx/machine_learning/NeuralNet/utility.hpp" // fit_validation_error()


// stochastically binarizes the outputs of each data point (defined below):
// each output value p becomes 1 with probability p, else 0
static DataSet prob_to_binary(
                              DataSet ds
                              );


//
// DESC:
//          Trains the network with the (improved) QRprop batch algorithm:
//          per-weight adaptive learning rates as in Rprop, plus a quadratic
//          (secant) step factor when the product of successive gradients
//          gives no sign information.
//
// INPUT:
//          int     max_epoch      :: number of training epochs (batch passes)
//          bool    early_stopping :: if true, restore the weights of the epoch
//                                    selected by fit_validation_error()
//          double  gamma_ini      :: initial per-weight learning rate
//          double  gamma_min      :: minimum per-weight learning rate
//          double  gamma_max      :: maximum per-weight learning rate
//          double  u              :: learning rate increase factor
//          double  d              :: learning rate decrease factor
//                                    NOTE: M. Riedmiller and H. Braun (developers of Rprop) recommend u = 1.2, d = 0.5 (also used by M. Pfister and R. Rojas in QRprop)
//          DataSet ds_tr          :: Training dataset (outputs are probabilities)
//          DataSet ds_val         :: Validation dataset (outputs are probabilities)
//
// OUTPUT:
//          (none -- the network's weights are updated in place)
//
// SIDE EFFECTS:
//          - appends per-epoch training/validation errors to ./training_errs.dat
//          - with early stopping, writes fit_valid_err.dat
//
// NOTES:
//     - See (for the *improved* version of QRprop):
//          M. Pfister and R. Rojas, "Hybrid Learning Algorithms for Neural Networks -- The adaptive Inclusion of Second Order Information"
//     - QRprop is a batch method
//     - M. Pfister and R. Rojas use gamma_min = 0.0003, gamma_max = 0.1, gamma_ini = 0.01
//
inline void NEURAL_NET::QRprop(
                               const int      max_epoch,
                               // -----
                               const bool     early_stopping,
                               // -----
                               const double   gamma_ini,
                               const double   gamma_min,
                               const double   gamma_max,
                               // -----
                               const double   u,
                               const double   d,
                               // -----
                               const DataSet &ds_tr,
                               const DataSet &ds_val
                               )
{
    // current weight vector (mirrors this->m_links)
    std::vector<double> w;
    w.reserve(this->m_links.size());
    for( const auto &link : this->m_links )
    {
        w.push_back(link.weight);
    }

    // CALCULATE

    // stochastically binarize the (probabilistic) training outputs
    DataSet _ds_tr = prob_to_binary(
                                    ds_tr
                                    );

    // initialize the prior error gradients to the initial error gradients
    std::vector<double> dEdw1;
    this->calulate_Ederiv(                 // (sic: project spelling of "calculate")
                          _ds_tr,
                          dEdw1
                          );

    std::vector<double> dEdw2 = dEdw1;     // gradient from two epochs back

    // set the initial learning rates
    std::vector<double> gamma1(w.size(), gamma_ini);

    // per-epoch error history and weight snapshots (used to set optimal weights)
    std::vector<double> Etr, Etr_stddev, Ev, Ev_stddev;
    std::vector<std::vector<double>> w_store;

    std::ofstream ofs("./training_errs.dat", std::ios::app);

    for( int k = 0; k < max_epoch; ++k )
    {
        // re-sample the binarized datasets every epoch
        _ds_tr = prob_to_binary(
                                ds_tr
                                );

        DataSet _ds_val = prob_to_binary(
                                         ds_val
                                        );

        std::vector<double> dEdw;
        this->calulate_Ederiv( _ds_tr, dEdw );

        for( std::vector<double>::size_type i = 0; i < w.size(); ++i )
        {
            // the following is a concise mix of step 1 (update the individual learning rates) and step 2 (update the weights)
            const double dEdw_dEdw1 = dEdw[i]*dEdw1[i];

            if( dEdw_dEdw1 > 0.0 )
            {
                // same gradient sign as last epoch: accelerate (Rprop-style step)
                const double gamma_i = std::min( (u*gamma1[i]), gamma_max );

                w[i] -= gamma_i*sign(dEdw[i]);

                gamma1[i] = gamma_i;
            }
            else if( dEdw_dEdw1 < 0.0 )
            {
                // sign change: the minimum was overstepped -- keep the weight and
                // learning rate, and zero the stored gradient so the next epoch
                // takes the quadratic branch below
                dEdw[i] = 0.0;
            }
            else
            {
                // quadratic (secant) approximation of the step factor, clamped to [d, 1/u]
                double qi;
                if( dEdw[i] != dEdw2[i] )
                {
                    qi = std::max( d, std::min( (1.0/u), std::fabs( dEdw[i]/(dEdw[i] - dEdw2[i]) ) ) );
                }
                else
                {
                    qi = 1.0/u;
                }

                const double gamma_i = std::max( (qi*gamma1[i]), gamma_min );

                w[i] -= gamma_i*sign(dEdw[i]);

                gamma1[i] = gamma_i;
            }
        }

        this->assign_weights(
                             w
                             );

        // record training/validation error for this epoch
        double Emean;
        double Estddev;
        std::vector<std::vector<double>> NN_out;

        std::tie( Emean, Estddev, NN_out) = this->parse_TS(
                                                           _ds_tr
                                                           );
        Etr.push_back(Emean);
        Etr_stddev.push_back(Estddev);

        std::tie( Emean, Estddev, NN_out) = this->parse_TS(
                                                           _ds_val
                                                           );
        Ev.push_back(Emean);
        Ev_stddev.push_back(Estddev);

        ofs << (k+1) << "   " << Etr.back() << "   " << Etr_stddev.back() << "   " << Ev.back() << "   " << Ev_stddev.back() << std::endl;

        w_store.push_back(w);

        // update all prior quantities
        // note: gamma_{k-1} is updated on the fly
        dEdw2 = dEdw1;
        dEdw1 = dEdw;
    }

    ofs.close();

    // no epoch was run (max_epoch <= 0): keep the current weights
    // (previously w_store.back() was undefined behavior on an empty vector)
    if( w_store.empty() )
    {
        return;
    }

    if( early_stopping )
    {
        // NOTE: See NOTEs in ::backpropagation().

        std::vector<double>::size_type indx;
        std::vector<double>            f;
        std::tie(
                 f,
                 indx
                 ) = fit_validation_error(
                                          Ev,
                                          Ev_stddev
                                          );

        std::ofstream ofs2("fit_valid_err.dat");

        ofs2 << "using point: " << indx << '\n';

        for( std::vector<double>::size_type i = 0; i < f.size(); ++i )
        {
            ofs2 << (i+1) << "   " << f[i] << '\n';
        }

        ofs2.close();

        this->assign_weights(w_store[indx]);
    }
    else
    {
        this->assign_weights(w_store.back());
    }
}

//
// DESC:
//          Stochastically binarizes the outputs of a dataset: each output
//          value p is replaced by 1 with probability p, else by 0
//          (via a uniform draw on [0, 1); p outside [0, 1] saturates to
//          always-0 or always-1).
//
// INPUT:
//          DataSet ds :: dataset whose outputs are probabilities
//                        (taken by value; the caller's dataset is untouched)
//
// OUTPUT:
//          copy of ds with every output replaced by a {0, 1} sample
//
static DataSet prob_to_binary(
                              DataSet ds
                              )
{
    for( auto &pt : ds.pt )
    {
        // range-for avoids the previous signed (int) vs. size_t index comparison
        for( auto &out_i : pt.out )
        {
            // Bernoulli draw: 1 with probability out_i
            const double r = rand_num_uniform_Mersenne_twister(0., 1.);

            out_i = ( r < out_i ) ? 1. : 0.;
        }
    }

    return ds;
}
